18 实践课-知识库检索工具开发与标注数据完善

知识库检索工具开发与标注数据完善

关联:索引

术语小抄(初学者版)

先修要求与环境(与 17 对齐)


  1. 输入校验:空问题/超长问题直接拒绝(这是质量门槛)。
  2. 检索证据:向量化 query,top-k 检索拿到候选 chunk。
  3. 结果判定:根据分数与阈值决定“能回答/需要澄清/无结果”。
  4. 生成回答:优先基于证据片段“引用式回答”,并给出来源字段(可追溯)。

本节默认设置(与 17 的示例一致):

这意味着:

四、代码模板:知识库检索问答工具(可直接运行)

功能目标:

from __future__ import annotations

import argparse
import json
import os
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Module-level cache of loaded embedding models, keyed by model name; get_model() reads and fills it.
_MODEL_CACHE: Dict[str, SentenceTransformer] = {}

@dataclass(frozen=True)
class Chunk:
    """One knowledge-base fragment, built from a single JSON line of chunks.jsonl."""
    # Identifier echoed back in every search hit.
    chunk_id: str
    # Fragment text; quoted in the answer and truncated for previews.
    text: str
    # String attributes of the fragment; the answer code reads the keys
    # source, doc_type, section, rule_id and alarm_code when present.
    metadata: Dict[str, str]

def get_model(model_name: str) -> SentenceTransformer:
    """Return the embedding model for ``model_name``, constructing it only on first use.

    Later calls with the same name are served from ``_MODEL_CACHE``.
    """
    cached = _MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    loaded = SentenceTransformer(model_name)
    _MODEL_CACHE[model_name] = loaded
    return loaded

def load_store(store_dir: Path) -> tuple[faiss.Index, List[Chunk]]:
    """Load the FAISS index and its chunk records from ``store_dir``.

    The directory must contain ``index.faiss`` and ``chunks.jsonl`` (one JSON
    object per line; blank lines are ignored).

    Args:
        store_dir: Directory holding the two store files.

    Returns:
        The index and the chunk list, where list position ``i`` is the record
        for index id ``i``.

    Raises:
        FileNotFoundError: If either store file is missing.
        ValueError: If the index and the chunk file hold different numbers of
            entries, i.e. ids can no longer be mapped to chunks reliably.
    """
    index_path = store_dir / "index.faiss"
    chunks_path = store_dir / "chunks.jsonl"

    if not index_path.is_file():
        raise FileNotFoundError(f"index not found: {index_path}")
    if not chunks_path.is_file():
        raise FileNotFoundError(f"chunks not found: {chunks_path}")

    index = faiss.read_index(str(index_path))

    # Split on "\n" only. str.splitlines() also breaks on U+0085/U+2028/U+2029,
    # which json.dumps(ensure_ascii=False) writes unescaped, so a chunk whose
    # text contains one of them would be cut into two invalid JSON lines.
    lines = chunks_path.read_text(encoding="utf-8").split("\n")
    chunks = [Chunk(**json.loads(ln)) for ln in lines if ln.strip()]

    # Search results are positions into `chunks`; refuse a store whose two
    # files have drifted apart instead of returning the wrong fragment later.
    if index.ntotal != len(chunks):
        raise ValueError(
            f"store out of sync: index has {index.ntotal} vectors "
            f"but {chunks_path} has {len(chunks)} chunks"
        )
    return index, chunks

def embed_query(model: SentenceTransformer, query: str) -> np.ndarray:
    """Encode one query string into a C-contiguous float32 array.

    The query is passed to ``model.encode`` as a one-element batch with
    ``normalize_embeddings=True``.
    """
    encode_options = {
        "batch_size": 1,
        "show_progress_bar": False,
        "convert_to_numpy": True,
        "normalize_embeddings": True,
    }
    embedding = model.encode([query], **encode_options)
    as_float32 = embedding.astype("float32")
    return np.ascontiguousarray(as_float32)

def retrieve(index: faiss.Index, chunks: List[Chunk], query_vec: np.ndarray, *, top_k: int) -> List[dict]:
    """Search ``index`` for ``query_vec`` and return the matching chunks as dicts.

    Each hit carries ``score``, ``chunk_id``, ``text`` and ``metadata``, in the
    order the index returned them. Negative ids are dropped.
    """
    score_rows, id_rows = index.search(query_vec, top_k)
    ranked = zip(score_rows[0].tolist(), id_rows[0].tolist())

    # Resolve ids to chunk records first, skipping the negative ones.
    resolved = ((similarity, chunks[position]) for similarity, position in ranked if position >= 0)
    return [
        {
            "score": float(similarity),
            "chunk_id": chunk.chunk_id,
            "text": chunk.text,
            "metadata": chunk.metadata,
        }
        for similarity, chunk in resolved
    ]

def pick_answer(hits: List[dict], *, score_threshold: float) -> tuple[Optional[str], List[dict]]:
    """Build a citation-style answer from the hits that reach ``score_threshold``.

    Args:
        hits: Search hits in ranking order; each is a dict that may carry
            ``score``, ``chunk_id``, ``text`` and ``metadata``.
        score_threshold: Minimum score a hit needs to be used as evidence.

    Returns:
        ``(answer, brief_hits)``. ``answer`` quotes the first passing hit and,
        when its metadata has any citation field, appends a source line.
        ``brief_hits`` summarises at most the first three passing hits.
        Returns ``(None, [])`` when no hit passes.
    """
    # Metadata keys used for the citation line and copied into each summary.
    citation_keys = ("source", "doc_type", "section", "rule_id", "alarm_code")

    def score_of(hit: dict) -> float:
        # One accessor for filtering and reporting: a hit without a score
        # counts as 0.0 in both places, so the summary cannot raise KeyError
        # for a hit the filter has already accepted.
        return float(hit.get("score", 0.0))

    passed = [h for h in hits if score_of(h) >= score_threshold]
    if not passed:
        return None, []

    top = passed[0]
    top_meta = top.get("metadata") or {}
    citation = " | ".join(v for v in (top_meta.get(k, "") for k in citation_keys) if v)
    snippet = (top.get("text") or "").strip().replace("\n", " ")

    answer = f"根据知识库命中片段:{snippet}"
    if citation:
        answer = f"{answer}\n来源:{citation}"

    brief_hits: List[dict] = []
    for h in passed[:3]:
        m = h.get("metadata") or {}
        brief = {"score": score_of(h), "chunk_id": str(h.get("chunk_id", ""))}
        brief.update((k, m.get(k, "")) for k in citation_keys)
        brief["text_preview"] = (h.get("text") or "")[:160].replace("\n", " ")
        brief_hits.append(brief)
    return answer, brief_hits

def qa(
    query: str,
    *,
    store_dir: str,
    model_name: str,
    top_k: int,
    score_threshold: float,
) -> dict:
    """Answer ``query`` from the vector store and return a JSON-serialisable result.

    Steps: validate the input (non-empty, at most 300 characters after
    stripping), load the store, embed and search, then build the answer.
    Failures are reported as ``{"ok": False, "error": {...}, "trace_id": ...}``
    with one of the codes INPUT_EMPTY, INPUT_TOO_LONG, STORE_LOAD_FAILED or
    RETRIEVE_FAILED. When no hit reaches the threshold the result is still
    ``ok`` but has empty ``hits`` and reports the top-1 score and source in
    ``meta``.
    """
    trace_id = uuid.uuid4().hex[:8]

    def failure(code: str, message: str) -> dict:
        # Every error response shares this shape.
        return {"ok": False, "error": {"code": code, "message": message}, "trace_id": trace_id}

    question = (query or "").strip()
    if not question:
        return failure("INPUT_EMPTY", "query is empty")
    if len(question) > 300:
        return failure("INPUT_TOO_LONG", "query too long (max 300)")

    try:
        index, chunks = load_store(Path(store_dir))
    except Exception as exc:
        return failure("STORE_LOAD_FAILED", str(exc))

    try:
        encoder = get_model(model_name)
        query_vec = embed_query(encoder, question)
        hits = retrieve(index, chunks, query_vec, top_k=top_k)
    except Exception as exc:
        return failure("RETRIEVE_FAILED", str(exc))

    answer, brief_hits = pick_answer(hits, score_threshold=score_threshold)
    if answer:
        return {
            "ok": True,
            "answer": answer,
            "hits": brief_hits,
            "trace_id": trace_id,
            "meta": {"top_k": top_k, "score_threshold": score_threshold},
        }

    # Nothing passed the threshold: report the best raw hit so the caller can
    # see how far off the threshold was.
    best = hits[0] if hits else None
    if best:
        best_score = float(best["score"])
        best_meta = best.get("metadata") or {}
    else:
        best_score = None
        best_meta = {}
    return {
        "ok": True,
        "answer": "未在知识库中找到足够相似的片段。请补充关键词(设备型号/告警码/规则编号/品级条件)或降低阈值后再试。",
        "hits": [],
        "trace_id": trace_id,
        "meta": {
            "top_k": top_k,
            "score_threshold": score_threshold,
            "top1_score": best_score,
            "top1_source": best_meta.get("source", ""),
        },
    }

def parse_args() -> argparse.Namespace:
    """Parse the QA tool's command-line options.

    ``--query`` is required. ``--store-dir`` and ``--model-name`` default to
    the STORE_DIR / MODEL_NAME environment variables when set.
    """
    env = os.environ
    option_specs = (
        ("--query", {"required": True}),
        ("--store-dir", {"default": env.get("STORE_DIR", "./faiss_store")}),
        ("--model-name", {"default": env.get("MODEL_NAME", "paraphrase-multilingual-MiniLM-L12-v2")}),
        ("--top-k", {"type": int, "default": 5}),
        ("--score-threshold", {"type": float, "default": 0.45}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def main() -> None:
    """CLI entry point: run one query and print the result as indented JSON."""
    # Switch to the script's own folder so relative paths such as the default
    # ./faiss_store resolve against it rather than the caller's directory.
    os.chdir(Path(__file__).resolve().parent)

    cli = parse_args()
    outcome = qa(
        cli.query,
        store_dir=cli.store_dir,
        model_name=cli.model_name,
        top_k=cli.top_k,
        score_threshold=cli.score_threshold,
    )
    rendered = json.dumps(outcome, ensure_ascii=False, indent=2)
    print(rendered)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

解释与自检要点:

1)把阈值 score_threshold(命令行参数 --score-threshold)分别设为 0.30/0.45/0.60,对同一问题跑三次,记录:命中条目数、top-1 来源是否正确、是否出现噪声命中。
提示:低阈值更“能答”,但更容易答错;高阈值更“谨慎”,但更容易无结果。

2)将输出格式改成“固定结构”,让任何结果都包含:ok、trace_id、answer、hits。
提示:这是工具契约要求,方便测试脚本断言字段存在。

标注数据不是直接“塞进向量库”,必须先转成“可检索的知识片段”:

三、代码模板:把苹果分拣标注数据转成 chunk 并合并入库(可直接运行)

说明:

from __future__ import annotations

import argparse
import json
import os
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

@dataclass(frozen=True)
class Chunk:
    """One knowledge-base fragment; read from and written to chunks.jsonl as one JSON object per line."""
    # Identifier of the fragment; annotation chunks use the form "ANN-<fruit_id>".
    chunk_id: str
    # Text that gets embedded and added to the index.
    text: str
    # String attributes describing the fragment (e.g. data_type, source, fruit_id, grade, rule_id).
    metadata: Dict[str, str]

def load_chunks_jsonl(path: Path) -> List[Chunk]:
    """Read ``path`` as UTF-8 JSON Lines and return one ``Chunk`` per non-blank line.

    Raises whatever ``json.loads`` or the ``Chunk`` constructor raise for a
    malformed line.
    """
    # Split on "\n" only. str.splitlines() also breaks on U+0085/U+2028/U+2029,
    # which json.dumps(ensure_ascii=False) writes unescaped, so a chunk whose
    # text contains one of them would be cut into two invalid JSON lines.
    lines = path.read_text(encoding="utf-8").split("\n")
    return [Chunk(**json.loads(ln)) for ln in lines if ln.strip()]

def save_chunks_jsonl(path: Path, chunks: List[Chunk]) -> None:
    """Overwrite ``path`` with ``chunks`` as UTF-8 JSON Lines.

    Records are joined with a newline (no trailing newline) and non-ASCII
    characters are written as-is.
    """
    serialized = []
    for chunk in chunks:
        serialized.append(json.dumps(asdict(chunk), ensure_ascii=False))
    path.write_text("\n".join(serialized), encoding="utf-8")

def embed_texts(model: SentenceTransformer, texts: List[str]) -> np.ndarray:
    """Encode ``texts`` in batches of 32 and return a C-contiguous float32 array.

    ``model.encode`` is called with ``normalize_embeddings=True``.
    """
    encode_options = {
        "batch_size": 32,
        "show_progress_bar": False,
        "convert_to_numpy": True,
        "normalize_embeddings": True,
    }
    embeddings = model.encode(texts, **encode_options)
    as_float32 = embeddings.astype("float32")
    return np.ascontiguousarray(as_float32)

def load_index_for_append(index_path: Path, vec_dim: int) -> faiss.Index:
    """Open the existing index at ``index_path`` for appending vectors of size ``vec_dim``.

    Raises:
        FileNotFoundError: If ``index_path`` is not a file.
        ValueError: If the index dimension differs from ``vec_dim``.
    """
    if index_path.is_file():
        existing = faiss.read_index(str(index_path))
        if existing.d == vec_dim:
            return existing
        raise ValueError(f"index dim mismatch: index.d={existing.d} vs vec_dim={vec_dim}")
    raise FileNotFoundError(f"index not found: {index_path}")

def annotation_to_chunks(records: List[dict]) -> List[Chunk]:
    """Turn apple-sorting annotation records into retrievable chunks.

    Each usable record becomes one chunk with id ``ANN-<fruit_id>``, a sentence
    of the form "conditions -> conclusion -> reason" as text, and metadata
    holding ``data_type``, ``source``, ``fruit_id`` plus ``grade`` / ``rule_id``
    when those are non-empty.

    Skipped records: entries that are not dicts, entries whose ``fruit_id`` is
    missing, null or blank, and repeats of a ``fruit_id`` already seen in this
    batch (the first occurrence wins, so chunk ids stay unique).

    Args:
        records: Parsed annotation objects, typically loaded from JSON.

    Returns:
        The chunks in input order.
    """

    def field(record: dict, key: str) -> str:
        # JSON null must read as "missing", not as the literal text "None".
        value = record.get(key)
        return "" if value is None else str(value).strip()

    out: List[Chunk] = []
    seen_ids = set()
    for r in records:
        if not isinstance(r, dict):
            continue

        fruit_id = field(r, "fruit_id")
        if not fruit_id or fruit_id in seen_ids:
            continue
        seen_ids.add(fruit_id)

        grade = field(r, "grade")
        diameter_mm = field(r, "diameter_mm")
        defect = field(r, "defect")
        decision = field(r, "decision")
        reason = field(r, "reason")
        rule_id = field(r, "rule_id")

        text = (
            f"苹果分拣标注样本:fruit_id={fruit_id}。"
            f"条件:果径={diameter_mm}mm,瑕疵={defect}。"
            f"结论:品级={grade},去向={decision}。"
            f"理由:{reason}。"
        )
        meta: Dict[str, str] = {
            "data_type": "annotation",
            "source": "apple_sorting_annotations",
            "fruit_id": fruit_id,
        }
        if grade:
            meta["grade"] = grade
        if rule_id:
            meta["rule_id"] = rule_id
        out.append(Chunk(chunk_id=f"ANN-{fruit_id}", text=text, metadata=meta))
    return out

def parse_args() -> argparse.Namespace:
    """Parse the merge script's command-line options.

    ``--store-dir`` and ``--model-name`` default to the STORE_DIR / MODEL_NAME
    environment variables when set; ``--input-json`` defaults to an empty string.
    """
    env = os.environ
    option_defaults = (
        ("--store-dir", env.get("STORE_DIR", "./faiss_store")),
        ("--model-name", env.get("MODEL_NAME", "paraphrase-multilingual-MiniLM-L12-v2")),
        ("--input-json", ""),
    )

    parser = argparse.ArgumentParser()
    for flag, fallback in option_defaults:
        parser.add_argument(flag, default=fallback)
    return parser.parse_args()

def main() -> None:
    """Convert annotation records to chunks and append them to the existing store.

    Reads the current ``chunks.jsonl``, builds chunks from ``--input-json`` (or
    from four built-in demo records), drops those whose chunk_id is already
    stored, embeds the rest, appends the vectors to ``index.faiss`` and
    rewrites ``chunks.jsonl``. Prints a JSON summary with ``added`` and
    ``total``.

    Raises:
        FileNotFoundError: If ``chunks.jsonl`` or ``index.faiss`` is missing.
        ValueError: If ``--input-json`` is not a JSON array, if the embedding
            size does not match the index, or if the index and ``chunks.jsonl``
            hold different numbers of entries.
        SystemExit: With code 0 when there is nothing new to add.
    """
    # Relative paths such as the default ./faiss_store resolve against the script's folder.
    project_dir = Path(__file__).resolve().parent
    os.chdir(project_dir)

    args = parse_args()
    trace_id = uuid.uuid4().hex[:8]

    store_dir = Path(args.store_dir)
    index_path = store_dir / "index.faiss"
    chunks_path = store_dir / "chunks.jsonl"

    if not chunks_path.is_file():
        raise FileNotFoundError(f"chunks not found: {chunks_path}")

    base_chunks = load_chunks_jsonl(chunks_path)
    existing_ids = {c.chunk_id for c in base_chunks}

    if args.input_json:
        records = json.loads(Path(args.input_json).read_text(encoding="utf-8"))
        if not isinstance(records, list):
            raise ValueError("--input-json must be a JSON array of objects")
        demo_annotations: List[dict] = records
    else:
        demo_annotations = [
            {"fruit_id": "A001", "diameter_mm": 82, "defect": "无明显瑕疵", "grade": "A", "decision": "A口", "reason": "果径达标且瑕疵低", "rule_id": "R-APPLE-01"},
            {"fruit_id": "A002", "diameter_mm": 76, "defect": "轻微擦伤<3%", "grade": "B", "decision": "B口", "reason": "果径达标且轻微瑕疵", "rule_id": "R-APPLE-01"},
            {"fruit_id": "A003", "diameter_mm": 74, "defect": "瑕疵约4%", "grade": "C", "decision": "C口", "reason": "瑕疵偏高或果径不足", "rule_id": "R-APPLE-01"},
            {"fruit_id": "A004", "diameter_mm": 80, "defect": "霉斑", "grade": "D", "decision": "复检口", "reason": "例外条款:霉斑需复检", "rule_id": "R-APPLE-EX"},
        ]
    ann_chunks_all = annotation_to_chunks(demo_annotations)
    # Skip chunks that are already stored so that re-running the script adds nothing twice.
    ann_chunks = [c for c in ann_chunks_all if c.chunk_id not in existing_ids]

    if not ann_chunks:
        print(json.dumps({"ok": True, "trace_id": trace_id, "added": 0, "total": len(base_chunks)}, ensure_ascii=False, indent=2))
        raise SystemExit(0)

    model = SentenceTransformer(args.model_name)
    new_vectors = embed_texts(model, [c.text for c in ann_chunks])
    index = load_index_for_append(index_path, int(new_vectors.shape[1]))

    # Index ids are positions in chunks.jsonl. If the two files already
    # disagree, appending would attach every new vector to the wrong record,
    # so stop before anything is written.
    if index.ntotal != len(base_chunks):
        raise ValueError(
            f"store out of sync: index has {index.ntotal} vectors "
            f"but {chunks_path} has {len(base_chunks)} chunks"
        )
    index.add(new_vectors)

    store_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(index_path))
    save_chunks_jsonl(chunks_path, base_chunks + ann_chunks)

    print(json.dumps({"ok": True, "trace_id": trace_id, "added": len(ann_chunks), "total": len(base_chunks) + len(ann_chunks)}, ensure_ascii=False, indent=2))

# Run the merge only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

解释与自检要点:

对比测试:标注数据追加前 vs 追加后(让学生直观看到“提升是什么”)

目标:用同一批问题做 A/B 对比,让学生看到标注数据带来的 2 类提升:

1)覆盖提升:原来“答不上/命中不稳”的现场问法,追加后能命中到更贴近问题的证据片段。
2)证据提升:命中片段更像“现场样本结论”(条件→结论→理由),便于直接引用回答并追溯来源。

  1. 先用“仅文档库”的版本运行问答工具,记录输出(特别看 hits 是否为空、以及 top-1 的 source/text_preview 是否相关)。
  2. 运行本节的“追加标注数据脚本”写入 ANN-A001~A004。
  3. 再用同样的问题运行问答工具,对比两次结果差异。

建议用这些提问来测(同一个问题连续跑 2 次:追加前、追加后):

# 说明:如果你发现“追加后也没命中”,先把阈值临时降到 0.20 做验证,确认命中后再调回 0.45 做稳定性优化

cd .\12_kb_qa_project

# 1)样本级问法(标注数据最擅长覆盖):追加后通常能命中 source=apple_sorting_annotations,text_preview 出现 fruit_id=A004
py .\qa_tool.py --query "fruit_id=A004 为什么要复检?" --top-k 10 --score-threshold 0.20 --store-dir ".\faiss_store"
py .\qa_tool.py --query "A004 霉斑 复检口" --top-k 10 --score-threshold 0.20 --store-dir ".\faiss_store"

# 2)字段组合问法(更贴近现场):追加后更容易命中“条件→结论→理由”样本片段
py .\qa_tool.py --query "果径80mm+霉斑 最终去向是什么?" --top-k 10 --score-threshold 0.20 --store-dir ".\faiss_store"

# 3)规则编号问法(追加后会多出“样本视角”的证据):命中 rule_id=R-APPLE-EX 的样本片段
py .\qa_tool.py --query "R-APPLE-EX 对应的处理是什么?" --top-k 10 --score-threshold 0.20 --store-dir ".\faiss_store"

你们要观察的“差异点”(写进实验记录):

目标:找一个“能答但不乱答”的阈值。

2)回答逻辑升级(不靠“更会说”,靠“更可验”)

建议升级点(按优先级):

  1. 先回答结论,再给证据:输出“结论 + 引用来源(source/doc_type/section/rule_id/alarm_code/fruit_id)”

  2. 命中多个来源时:展示 top-3 的简短摘要,提示“可继续追问”缩小范围

  3. 无结果时:返回“需要澄清的问题清单”(例如缺少设备型号/告警码/条件字段)

  4. 知识库加载失败(可类比“数据库连接失败”):

  1. 检索无结果:

课程思政融入点(地方产业项目导向:质量与责任)

大模型任务(教师可发放,学生可复用)

作业(按要求布置)

1)提交知识库检索工具完整代码(含标注数据融合逻辑),附功能说明。

2)提交测试用例执行报告(含各用例执行结果截图),附检索效果优化说明。

3)撰写 200 字左右分析,说明标注数据对知识库完整性及检索准确性的提升作用。

参考与延伸